#!/usr/bin/env python3
# -*- coding:UTF-8 -*-

import AutoExecUtils
import os
import traceback
import datetime
import argparse
import sys
import json
from bson.json_util import dumps, loads
from DSLTools import Parser, Interpreter

# Module-level cache of application-system inspect definitions.
# getAppMetrisDef() looks entries up here by application system id; nothing in
# this file stores into it, so as far as this file shows every lookup misses.
# NOTE(review): presumably meant to be filled after a definition is parsed — confirm.
appSystemParsedCache = {}


def usage():
    """Print the command-line synopsis and terminate the process.

    Raises:
        SystemExit: always, with exit code 1.
    """
    pname = os.path.basename(__file__)
    print(pname + " --outputfile <path> ")
    # Use sys.exit() rather than the site-provided exit(): the latter is an
    # interactive-shell helper and does not exist when site is not loaded.
    sys.exit(1)


def getObjCatPkDef(objCat, customPkDef):
    """Return the primary-key attribute names for an object category.

    Args:
        objCat: Category name, e.g. "OS" or "DBINS" (matched case-sensitively).
        customPkDef: Key list to use when the category has no built-in
            definition; returned unchanged (may be None).

    Returns:
        A freshly built list of attribute names for built-in categories,
        otherwise ``customPkDef``.
    """
    # Built-in categories grouped by the shape of their primary key.
    pkGroups = (
        (("INS",), ("MGMT_IP", "PORT")),
        (("DB", "CLUSTER"), ("PRIMARY_IP", "PORT", "NAME")),
        (("DBINS",), ("MGMT_IP", "PORT", "INSTANCE_NAME")),
        (("OS", "SERVERDEV", "VIRTUALIZED", "K8S", "UNKNOWN"), ("MGMT_IP",)),
        (("HOST",), ("MGMT_IP", "BOARD_SERIAL")),
        (
            ("NETDEV", "SECDEV", "SWITCH", "FIREWALL", "LOADBALANCER", "STORAGE", "FCSWITCH"),
            ("MGMT_IP", "SN"),
        ),
        (("CONTAINER",), ("MGMT_IP", "CONTAINER_ID")),
    )
    builtinPk = {}
    for categories, attrNames in pkGroups:
        for category in categories:
            builtinPk[category] = list(attrNames)

    return builtinPk.get(objCat, customPkDef)


def getLevelNo(level):
    """Translate an alert level name into a numeric severity.

    The name is compared case-insensitively: WARN -> 1, CRITICAL -> 2,
    FATAL -> 3, and any other name -> 0.
    """
    severityByName = {"WARN": 1, "CRITICAL": 2, "FATAL": 3}
    return severityByName.get(level.upper(), 0)


def getAlertSummarized(alerts):
    """Aggregate alert dicts into per-level, per-field and overall summaries.

    Args:
        alerts: Iterable of dicts, each read for the keys "ruleLevel",
            "jsonPath", "ruleSeq" and "ruleName".

    Returns:
        A dict with three entries:
        "alertOutline" - number of alerts per level name;
        "alertSummarized" - {"totalCount": <alert count>, "status": <most
        severe level seen, "NORMAL" when nothing outranks it>};
        "alertFields" - one entry per distinct jsonPath (in first-seen order)
        holding the most severe level plus every rule seq/name that hit it.
    """
    worstLevel = "NORMAL"
    levelCounts = {}
    fieldIndex = {}
    alertTotal = 0

    for item in alerts:
        alertTotal += 1
        itemLevel = item["ruleLevel"]
        levelCounts[itemLevel] = levelCounts.get(itemLevel, 0) + 1
        if getLevelNo(worstLevel) < getLevelNo(itemLevel):
            worstLevel = itemLevel

        path = item["jsonPath"]
        entry = fieldIndex.get(path)
        if entry is None:
            # First alert on this field: open a new per-field entry.
            fieldIndex[path] = {
                "alertLevel": itemLevel,
                "alertField": path,
                "ruleSeqs": [item["ruleSeq"]],
                "ruleNames": [item["ruleName"]],
            }
            continue

        # Field already seen: keep the more severe level, record the rule.
        if getLevelNo(entry["alertLevel"]) < getLevelNo(itemLevel):
            entry["alertLevel"] = itemLevel
        entry["ruleSeqs"].append(item["ruleSeq"])
        entry["ruleNames"].append(item["ruleName"])

    return {
        "alertOutline": levelCounts,
        "alertSummarized": {"totalCount": alertTotal, "status": worstLevel},
        "alertFields": list(fieldIndex.values()),
    }


def getAppMetrisDef(db, appSysIds, defName, parentRuleMap):
    """Load the application-level threshold definitions for the given app systems.

    Definitions found in ``appSystemParsedCache`` are reused; the remaining ids
    are looked up in the ``_inspectdef_app`` collection. Every rule gets a
    ``ruleSeq`` and an ``appSystemId``, and its ``AST`` is (re)generated when it
    is missing or older than the rule's ``_updatetime``. A definition whose
    rules were re-parsed has its ``thresholds`` written back to the collection
    (and its ``_id`` key removed from the returned dict).

    NOTE(review): nothing visible in this file fills ``appSystemParsedCache``,
    and it is keyed by app system id only (not by ``defName``) - confirm the
    intended caching scheme before populating it.

    Args:
        db: Database handle indexable by collection name.
        appSysIds: Application system ids (a single id is also accepted).
        defName: Name of the inspect definition the app rules belong to.
        parentRuleMap: Base rules keyed by ``ruleUuid``; an app rule flagged
            ``isOverWrite == 1`` inherits the ``ruleSeq`` of its parent.

    Returns:
        List of definition dicts (cached ones first), empty when no ids given.
    """
    if not appSysIds:
        return []
    # Tolerate a single id instead of a collection of ids.
    if isinstance(appSysIds, (str, bytes)) or not hasattr(appSysIds, "__iter__"):
        appSysIds = [appSysIds]

    newAppDefs = []
    queryAppSysIds = []
    for appSysId in appSysIds:
        cachedAppDef = appSystemParsedCache.get(appSysId)
        if cachedAppDef is not None:
            newAppDefs.append(cachedAppDef)
        else:
            queryAppSysIds.append(appSysId)

    appDefCollection = db["_inspectdef_app"]
    # Find the app-system custom thresholds. Only ids that missed the cache are
    # queried; querying every id would return cached definitions a second time.
    appDefs = []
    if queryAppSysIds:
        appDefs = appDefCollection.find({"appSystemId": {"$in": queryAppSysIds}, "name": defName})

    for appDef in appDefs:
        ruleSeq = 5000
        ruleNeedUpdate = False
        # Naive UTC, like the "_astgentime" values this function stores.
        currentTime = datetime.datetime.utcnow()
        for ruleDef in appDef["thresholds"]:
            parentRule = None
            if ruleDef.get("isOverWrite") == 1:
                parentRule = parentRuleMap.get(ruleDef.get("ruleUuid"))
                if parentRule is None:
                    print("WARN: Can not find parent rule for overwrite rule:{}, treat it as a new rule.".format(ruleDef.get("ruleUuid")))

            if parentRule is not None:
                ruleDef["ruleSeq"] = parentRule.get("ruleSeq")
            else:
                ruleSeq = ruleSeq + 1
                ruleDef["ruleSeq"] = "%s#%d" % (appDef.get("appSystemAbbrName"), ruleSeq)
            ruleDef["appSystemId"] = appDef["appSystemId"]

            # Compare the AST generation time with the rule modification time;
            # a rule without "_updatetime" counts as modified just now.
            updateTime = ruleDef.get("_updatetime", currentTime)
            astGenTime = ruleDef.get("_astgentime")
            try:
                astIsStale = "AST" not in ruleDef or astGenTime is None or updateTime is None or astGenTime < updateTime
            except TypeError:
                # Timestamps that cannot be compared: regenerate to be safe.
                astIsStale = True

            if astIsStale:
                ruleDef["_astgentime"] = currentTime
                ast = Parser(ruleDef["rule"])
                ruleDef["AST"] = ast.asList()
                ruleNeedUpdate = True

        if ruleNeedUpdate:
            _id = appDef.pop("_id")
            appDefCollection.update_one({"_id": _id}, {"$set": {"thresholds": appDef["thresholds"]}})
        newAppDefs.append(appDef)

    return newAppDefs


def getMetricsDef(db, objCat, objType):
    """Load the inspect definition for an object category/type.

    The definition is read from the ``_inspectdef`` collection, first for the
    exact ``objType`` and then with a null ``filter._OBJ_TYPE``. Each threshold
    rule is numbered ("#1", "#2", ...) and its ``AST`` is (re)generated when it
    is missing or older than the rule's ``_updatetime``; when any rule was
    re-parsed the ``thresholds`` are written back and ``_id`` is removed from
    the returned dict.

    Args:
        db: Database handle indexable by collection name.
        objCat: Object category; upper-cased to build "COLLECT_<CAT>".
        objType: Object type used for the first lookup (may be None).

    Returns:
        The definition dict, or ``{"name": "empty", "fields": [],
        "thresholds": []}`` when no definition exists.
    """
    inspectDefCollection = db["_inspectdef"]
    collectName = "COLLECT_" + objCat.upper()
    inspectDef = inspectDefCollection.find_one({"collection": collectName, "filter._OBJ_TYPE": objType})
    if inspectDef is None:
        inspectDef = inspectDefCollection.find_one({"collection": collectName, "filter._OBJ_TYPE": None})

    if inspectDef is None:
        return {"name": "empty", "fields": [], "thresholds": []}

    ruleNeedUpdate = False
    # Naive UTC, like the "_astgentime" values this function stores.
    currentTime = datetime.datetime.utcnow()
    for ruleSeq, ruleDef in enumerate(inspectDef["thresholds"], 1):
        ruleDef["ruleSeq"] = "#%d" % (ruleSeq)

        # Compare the AST generation time with the rule modification time;
        # a rule without "_updatetime" counts as modified just now.
        updateTime = ruleDef.get("_updatetime", currentTime)
        astGenTime = ruleDef.get("_astgentime")
        try:
            astIsStale = "AST" not in ruleDef or astGenTime is None or updateTime is None or astGenTime < updateTime
        except TypeError:
            # Timestamps that cannot be compared: regenerate to be safe.
            astIsStale = True

        if astIsStale:
            ruleDef["_astgentime"] = currentTime
            ast = Parser(ruleDef["rule"])
            ruleDef["AST"] = ast.asList()
            ruleNeedUpdate = True

    if ruleNeedUpdate:
        _id = inspectDef.pop("_id")
        inspectDefCollection.update_one({"_id": _id}, {"$set": {"thresholds": inspectDef["thresholds"]}})

    return inspectDef


# Fetch the full stored CMDB record for each collected object


def getObjCatFullData(datalist):
    """Complete incrementally collected records with their stored full record.

    For each record the primary key is built from its category's PK definition
    and looked up in the ``COLLECT_<category>`` collection; when a stored
    record exists, the collected values are laid over it. Records that cannot
    be looked up (no category, no PK definition, missing PK attribute) are
    passed through unchanged, exactly once.

    Uses the module-global ``db`` that the script entry point assigns.

    Args:
        datalist: Iterable of collected record dicts.

    Returns:
        A new list with one (possibly completed) record per input record.
    """
    new_datalist = []
    for data in datalist:
        objCat = data.get("_OBJ_CATEGORY")
        if objCat is None:
            # No category means no collection name can be derived.
            print("WARN: Data not defined _OBJ_CATEGORY, skip full data lookup.")
            new_datalist.append(data)
            continue

        objType = data.get("_OBJ_TYPE")
        typeName = objType if objType is not None else objCat

        pkDef = getObjCatPkDef(objCat, data.get("PK"))
        if not pkDef:
            # An empty filter would match an arbitrary stored document.
            print("WARN: {} has no PK definition, skip full data lookup.".format(typeName))
            new_datalist.append(data)
            continue

        # Build the primary-key filter from the PK definition.
        primaryKey = {}
        pkInvalid = False
        for pKey in pkDef:
            if pKey not in data:
                pkInvalid = True
                print("WARN: {} not contain PK attribute:{}.".format(typeName, pKey))
                continue
            pVal = data[pKey]
            primaryKey[pKey] = pVal
            if pVal is None or pVal == "":
                print("WARN: {} PK attribute:{} is empty.".format(typeName, pKey))

        if pkInvalid:
            # Incomplete key: keep the record as collected and do NOT query,
            # otherwise it would be appended a second time below and could be
            # merged with an unrelated stored document.
            new_datalist.append(data)
            continue

        fullData = db["COLLECT_" + objCat].find_one(primaryKey, {"_id": False})
        if fullData is None:
            fullData = data
        else:
            # Freshly collected values win over the stored ones.
            fullData.update(data)
        new_datalist.append(fullData)
    return new_datalist


def _samePort(dataPort, nodePort):
    """Return True when both values denote the same integer port."""
    try:
        return int(dataPort) == int(nodePort)
    except (TypeError, ValueError):
        return False


def getInspectData(nodeInfo):
    """Extract the collected records that belong to the given node.

    Reads the node output of the previous steps. Output keys starting with
    "inspect/" contribute availability data, keys starting with "cmdbcollect/"
    contribute collected records, and keys starting with "svcinspect/"
    contribute both (list data is completed through getObjCatFullData). The
    records of the last plugin that returned data are then filtered by the
    node's host, port and name, and the availability data is merged into each
    matching record.

    Args:
        nodeInfo: Node dict; "host", "port" and "nodeName" are read.

    Returns:
        A non-empty list of record dicts. When nothing matches (or the node
        output is missing) the list holds a single record built from the
        availability data with _OBJ_CATEGORY/_OBJ_TYPE set to "EMPTY".
    """
    ip = nodeInfo.get("host")
    port = nodeInfo.get("port")
    nodeName = nodeInfo.get("nodeName")

    dataRecords = []
    availData = {}
    collectData = []

    # Load the output of the previous steps. CMDB collection may return several
    # records, so the ones matching this node have to be picked out.
    outputData = AutoExecUtils.loadNodeOutput()
    if outputData is None:
        print("WARN: Node output data is empty.")
        # Continue with no plugin output so the caller still gets the EMPTY
        # record instead of None.
        outputData = {}

    for key, value in outputData.items():
        if key.startswith("inspect/"):
            # Results of the availability plugins
            # (pingcheck, urlcheck, sshcheck, agentcheck, tcpcheck).
            aInspectData = value.get("DATA")
            if aInspectData is not None:
                availData.update(aInspectData)
            continue

        if key.startswith("svcinspect/"):
            # Service inspect plugins return either a plain availability map or
            # a list of collected records that also carry availability data.
            svcData = value.get("DATA")
            if not isinstance(svcData, list):
                if svcData is not None:
                    availData.update(svcData)
                continue

            if svcData and svcData[0] is not None:
                availData.update(svcData[0])

            # The thresholds are matched against full records, so complete the
            # incremental svcinspect records with the stored data.
            value["DATA"] = getObjCatFullData(svcData)
        elif not key.startswith("cmdbcollect/"):
            continue

        pluginData = value.get("DATA")
        if pluginData is None:
            # Keep whatever an earlier plugin delivered instead of replacing it
            # with None (which could not be iterated below).
            print("WARN: Plugin {} did not return collect data.".format(key))
            continue
        collectData = pluginData

    if not collectData:
        print("WARN: Can not find collect data in node output.")

    if "AVAILABILITY" in availData:
        availData["AVAILABILITY"] = availData["AVAILABILITY"] != 0
    else:
        availData["AVAILABILITY"] = False
        errMsg = "AVAILABILITY attribute not found in inspect data, must use availability tools(pingcheck,urlcheck,sshcheck,agentcheck,tcpcheck...) to collect availability first."
        availData["ERROR_MESSAGE"] = errMsg

    # Categories present in the collected data.
    collectedCats = {data.get("_OBJ_CATEGORY") for data in collectData}

    for data in collectData:
        objCat = data.get("_OBJ_CATEGORY")
        if objCat is None:
            print("WARN: Data not defined _OBJ_CATEGORY.")
            # default=str keeps this diagnostic dump from failing on values
            # json cannot encode (e.g. datetimes).
            print(json.dumps(data, default=str))
            continue

        # Only records whose address matches the node are returned.
        if not (data.get("MGMT_IP") == ip or data.get("PRIMARY_IP") == ip):
            print("WARN: Data not match IP:{} but:{}.".format(ip, data.get("MGMT_IP", data.get("PRIMARY_IP"))))
            continue

        if port is not None:
            if "PORT" not in data:
                print("WARN: Data has no port.")
                continue
            if not _samePort(data["PORT"], port):
                print("WARN: Data not match port:{} but:{}.".format(port, data["PORT"]))
                continue

        if objCat == "DBINS" and data.get("INSTANCE_NAME") != nodeName:
            continue
        if objCat == "DB" and data.get("NAME") != nodeName:
            continue
        if objCat == "HOST" and "OS" in collectedCats:
            # HOST records are dropped when OS records were collected too.
            continue

        print("INFO: Data matched {}/{}:{}/{}.".format(objCat, ip, port, nodeName))
        # Merge a per-record copy so that one record's RESOURCE_ID does not
        # leak into the following records through the shared dict.
        merged = dict(availData)
        if "RESOURCE_ID" in data:
            merged["RESOURCE_ID"] = data["RESOURCE_ID"]
        data.update(merged)
        dataRecords.append(data)

    if not dataRecords:
        print("WARN: Inspect data is empty, set _OBJ_CATEGORY to EMPTY.")
        availData["_OBJ_CATEGORY"] = "EMPTY"
        availData["_OBJ_TYPE"] = "EMPTY"
        dataRecords.append(availData)

    return dataRecords


def inspectData(nodeInfo):
    """Evaluate the inspect rules against the node's collected data.

    For every record returned by getInspectData() the matching inspect
    definition (plus application-level overrides) is loaded, each rule is
    resolved by the DSL Interpreter, and a report document is built.

    Uses the module-global ``db`` that the script entry point assigns.

    Args:
        nodeInfo: Node dict; "resourceId" and "host" are required, "port" and
            "appSystemId" are optional.

    Returns:
        List of report dicts. Each holds RESOURCE_ID, MGMT_IP, PORT, _execuser,
        the fields selected by the definition, and "_inspect_result" (status,
        alerts, matched thresholds and summaries).
    """
    rptDataRecords = []
    execUser = os.getenv("AUTOEXEC_USER")
    inspectDefMap = {}
    objCatFieldsMap = {}
    # Data collected by CMDB for this single node.
    dataRecords = getInspectData(nodeInfo) or []
    for dataCollected in dataRecords:
        if dataCollected is None:
            continue

        availability = dataCollected["AVAILABILITY"]
        objCat = dataCollected["_OBJ_CATEGORY"]
        objType = dataCollected.get("_OBJ_TYPE")

        # Definitions depend on category AND type, so cache on both. The cached
        # definition must be fetched on a hit as well; otherwise the second
        # record of a category would run against an empty definition.
        defKey = (objCat, objType)
        if defKey not in inspectDefMap:
            loadedDef = getMetricsDef(db, objCat, objType)
            neededFields = {"AVAILABILITY": 1, "ERROR_MESSAGE": 1, "RESPONSE_TIME": 1}
            for fieldDef in loadedDef.get("fields") or []:
                if fieldDef.get("selected") == 1:
                    neededFields[fieldDef["name"]] = 1
            inspectDefMap[defKey] = loadedDef
            objCatFieldsMap[defKey] = neededFields
        inspectDef = inspectDefMap[defKey]
        needFields = objCatFieldsMap[defKey]

        resourceId = dataCollected.get("RESOURCE_ID")
        if resourceId is None or resourceId == "0":
            # NOTE(review): the looked-up id only decides whether the record is
            # skipped; the report itself always carries nodeInfo["resourceId"].
            # The original had an unreachable fallback to nodeInfo["resourceId"]
            # here - confirm which behaviour is intended.
            if getResourceId(dataCollected) is None:
                print("WARN: Can not find resource Id for {}/{}:{}/{}.".format(objCat, dataCollected.get("MGMT_IP"), dataCollected.get("PORT"), dataCollected.get("NAME")))
                continue

        appSysIds = nodeInfo.get("appSystemId", None)
        defName = inspectDef["name"]

        thresholdsMap = {}
        inspectRulesMap = {}
        for ruleDef in inspectDef["thresholds"]:
            # Fall back to the rule sequence so that rules without a ruleUuid
            # do not overwrite each other under one shared key.
            ruleKey = ruleDef.get("ruleUuid")
            if ruleKey is None:
                ruleKey = ruleDef.get("ruleSeq")
            thresholdsMap[ruleKey] = ruleDef
            inspectRulesMap[ruleDef.get("ruleSeq")] = ruleDef

        # Application-level rules replace base rules that share their ruleUuid.
        for appDef in getAppMetrisDef(db, appSysIds, defName, thresholdsMap):
            for appRuleDef in appDef["thresholds"]:
                ruleKey = appRuleDef.get("ruleUuid")
                if ruleKey is None:
                    ruleKey = appRuleDef.get("ruleSeq")
                thresholdsMap[ruleKey] = appRuleDef
                inspectRulesMap[appRuleDef.get("ruleSeq")] = appRuleDef

        rptData = {"RESOURCE_ID": nodeInfo["resourceId"], "MGMT_IP": nodeInfo["host"], "_execuser": execUser}

        # Copy the fields requested by the inspect definition into the report.
        for needField in needFields:
            if needField in dataCollected:
                rptData[needField] = dataCollected[needField]

        # Each resolved alert is a dict such as:
        #   {"jsonPath": "$.DISKS[0].CAPACITY", "ruleLevel": "WARN",
        #    "ruleAppId": 234315, "ruleSeq": "ABS#12", "ruleName": "...",
        #    "fieldValue": 92}
        alerts = []
        for inspectRule in thresholdsMap.values():
            ast = inspectRule["AST"]
            interpreter = Interpreter(AST=ast, ruleAppId=inspectRule.get("appSystemId"), ruleSeq=inspectRule["ruleSeq"], ruleName=inspectRule["name"], ruleLevel=inspectRule["level"], data=dataCollected)
            alerts.extend(interpreter.resolve())

        # The built-in "empty" definition has no label; fall back to its name.
        inspectResult = {"status": "NORMAL", "name": inspectDef["name"], "label": inspectDef.get("label", inspectDef["name"]), "alerts": alerts}

        rptData["PORT"] = nodeInfo.get("port")

        # Keep the rules that produced alerts in the report so they can be shown.
        alertThresHolds = {}
        for alert in alerts:
            ruleSeq = alert["ruleSeq"]
            alertThresHolds[ruleSeq] = inspectRulesMap[ruleSeq]
        inspectResult["hasAlert"] = bool(alerts)

        if not availability:
            inspectResult["hasAlert"] = True
            alerts.append({"jsonPath": "$.AVAILABILITY", "ruleLevel": "CRITICAL", "ruleSeq": "#01", "ruleName": "可用性告警", "fieldValue": availability})
            alertThresHolds["#01"] = {"name": "可用性告警", "level": "CRITICAL", "ruleSeq": "#01", "rule": "$.AVAILABILITY != True"}

            if rptData.get("ERROR_MESSAGE"):
                alerts.append({"jsonPath": "$.ERROR_MESSAGE", "ruleLevel": "CRITICAL", "ruleName": "错误信息非空", "ruleSeq": "#02", "fieldValue": rptData.get("ERROR_MESSAGE")})
                alertThresHolds["#02"] = {"name": "错误信息非空", "level": "CRITICAL", "ruleSeq": "#02", "rule": '$.ERROR_MESSAGE != ""'}

        inspectResult["thresholds"] = alertThresHolds

        alertSum = getAlertSummarized(alerts)
        alertSummarized = alertSum["alertSummarized"]
        inspectResult["alertOutline"] = alertSum["alertOutline"]
        inspectResult["alertFields"] = alertSum["alertFields"]
        inspectResult["totalCount"] = alertSummarized["totalCount"]
        inspectResult["status"] = alertSummarized["status"]
        rptData["_inspect_result"] = inspectResult

        rptDataRecords.append(rptData)
    return rptDataRecords


def saveReport(db, ciType, resourceId, rptData, statusTarget="inspect"):
    """Store the latest report of a resource and publish its status.

    The report is stamped with ``_report_time`` and ``_jobid`` (from the
    AUTOEXEC_JOBID environment variable), upserted into INSPECT_REPORTS by
    RESOURCE_ID, and handed to saveHisReport(). Depending on ``statusTarget``
    the inspect status ("inspect", "both", "all") and/or the monitor status
    ("monitor", "both", "all") is then updated through AutoExecUtils.

    Args:
        db: Database handle indexable by collection name.
        ciType: CI type passed to the status updates.
        resourceId: Resource id passed to the status updates.
        rptData: Report dict; modified in place.
        statusTarget: Which status to update.
    """
    reports = db["INSPECT_REPORTS"]

    rptData["_report_time"] = datetime.datetime.utcnow()
    rptData["_jobid"] = os.getenv("AUTOEXEC_JOBID")

    reports.replace_one({"RESOURCE_ID": rptData["RESOURCE_ID"]}, rptData, upsert=True)
    for fieldName, indexName in (("RESOURCE_ID", "idx_pk"), ("_jobid", "idx_jobid")):
        reports.create_index([(fieldName, 1)], name=indexName)
    print("INFO: Save report data success.")

    saveHisReport(db, ciType, resourceId, rptData)

    if statusTarget in ("inspect", "both", "all"):
        outcome = rptData["_inspect_result"]
        AutoExecUtils.updateInspectStatus(ciType, resourceId, outcome["status"], outcome["totalCount"])
        print("INFO: Update inspect status for node success.")
    if statusTarget in ("monitor", "both", "all"):
        outcome = rptData["_inspect_result"]
        AutoExecUtils.updateMonitorStatus(ciType, resourceId, outcome["status"], outcome["totalCount"])
        print("INFO: Update monitor status for node success.")


def saveHisReport(db, ciType, resourceId, rptData):
    """Append the report to the INSPECT_REPORTS_HIS history collection.

    Also makes sure the history indexes exist: a (RESOURCE_ID, _report_time)
    index and a TTL index on _report_time. ``ciType`` and ``resourceId`` are
    accepted but not used here.
    """
    history = db["INSPECT_REPORTS_HIS"]
    history.insert_one(rptData)

    # 90 days, expressed in seconds (7776000).
    ttlSeconds = 90 * 24 * 60 * 60
    # A (_jobid, _report_time) index is currently disabled.
    indexSpecs = (
        ([("RESOURCE_ID", 1), ("_report_time", 1)], {"name": "idx_pk"}),
        ([("_report_time", 1)], {"name": "idx_ttl", "expireAfterSeconds": ttlSeconds}),
    )
    for indexKeys, indexOptions in indexSpecs:
        history.create_index(indexKeys, **indexOptions)

    print("INFO: Save report history data success.")


# Reverse-lookup the resource id in the resource center


def getResourceId(jsonData):
    """Look up the resource id that corresponds to a collected record.

    Args:
        jsonData: Record dict; "_OBJ_CATEGORY", "_OBJ_TYPE", "MGMT_IP", "PORT"
            and "NAME" are read, all of them optional.

    Returns:
        The "id" of the single resource returned by
        AutoExecUtils.getResourceInfoList, or None when the record has neither
        an IP nor a name, or when the lookup is not unambiguous.
    """
    objCategory = jsonData.get("_OBJ_CATEGORY")
    objType = jsonData.get("_OBJ_TYPE")
    mgmtIp = jsonData.get("MGMT_IP")
    port = jsonData.get("PORT")
    name = jsonData.get("NAME")

    # Nothing to search by. (The original parenthesised this test wrongly, so a
    # missing IP was never detected.)
    if (mgmtIp is None or mgmtIp == "") and (name is None or name == ""):
        return None

    # Containers are looked up without a port.
    lookupPort = None if objCategory == "CONTAINER" else port
    resourceList = AutoExecUtils.getResourceInfoList(mgmtIp, lookupPort, name, objType)
    if resourceList is not None and len(resourceList) == 1:
        return resourceList[0]["id"]
    return None


if __name__ == "__main__":
    # Script entry point: resolve the node definition and output path, run the
    # inspection for the node and save one report per returned record.
    parser = argparse.ArgumentParser()
    parser.add_argument("--outputfile", default="", help="Output json file path for node")
    parser.add_argument("--node", default="", help="Execution node json")
    parser.add_argument("--statustarget", default="inspect", help="Inspect status save target, all|inspect|monitor")
    args = parser.parse_args()
    outputPath = args.outputfile
    node = args.node
    statusTarget = args.statustarget

    dbclient = None
    nodeInfo = {}
    try:
        hasOptError = False
        # The node definition comes from --node or, failing that, AUTOEXEC_NODE.
        if not node:
            node = os.getenv("AUTOEXEC_NODE")
        if not node:
            print("ERROR: Can not find node definition.\n")
            hasOptError = True
        else:
            nodeInfo = json.loads(node)

        # --outputfile is exported through the environment; otherwise the
        # environment value is used.
        if not outputPath:
            outputPath = os.getenv("NODE_OUTPUT_PATH")
        else:
            os.environ["NODE_OUTPUT_PATH"] = outputPath

        # An empty value is as unusable as a missing one.
        if not outputPath:
            print("ERROR: Must set environment variable NODE_OUTPUT_PATH or defined option --outputfile.\n")
            hasOptError = True

        if hasOptError:
            usage()

        ciType = nodeInfo["nodeType"]
        resourceId = nodeInfo["resourceId"]
        # "db" has to stay a module-level name: getObjCatFullData() and
        # inspectData() read it as a global.
        (dbclient, db) = AutoExecUtils.getDB()

        rptDataRecords = inspectData(nodeInfo)
        for rptData in rptDataRecords:
            saveReport(db, ciType, resourceId, rptData, statusTarget)

    except Exception:
        # Never print the node password.
        if isinstance(nodeInfo, dict) and "password" in nodeInfo:
            nodeInfo["password"] = "******"
        print("ERROR: Unknow Error for node {}, {}".format(nodeInfo, traceback.format_exc()))
        sys.exit(-1)
    finally:
        if dbclient is not None:
            dbclient.close()
